{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Semantic Segmentation Demo\n",
    "\n",
    "This is a notebook for running the benchmark semantic segmentation network from the [ADE20K MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).\n",
    "\n",
    "The code for this notebook is available at\n",
    "https://github.com/CSAILVision/semantic-segmentation-pytorch/tree/master/notebooks\n",
    "\n",
    "It can be run on Colab at https://colab.research.google.com/github/CSAILVision/semantic-segmentation-pytorch/blob/master/notebooks/DemoSegmenter.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Environment Setup\n",
    "\n",
    "First, if we are running on Colab, download the code and pretrained models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "# Colab-specific setup: exit early when not running on Colab\n",
    "!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit\n",
    "# Send both stdout and stderr to the log (the file redirect must come before 2>&1)\n",
    "pip install yacs >> install.log 2>&1\n",
    "git init >> install.log 2>&1\n",
    "git remote add origin https://github.com/CSAILVision/semantic-segmentation-pytorch.git 2>> install.log\n",
    "git pull origin master >> install.log 2>&1\n",
    "DOWNLOAD_ONLY=1 ./demo_test.sh 2>> install.log"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and utility functions\n",
    "\n",
    "We need PyTorch, NumPy, and the code for the segmentation model, plus some utilities for visualizing the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# System libs\n",
    "import os, csv, torch, numpy, scipy.io, PIL.Image, torchvision.transforms\n",
    "# Our libs\n",
    "from mit_semseg.models import ModelBuilder, SegmentationModule\n",
    "from mit_semseg.utils import colorEncode\n",
    "\n",
    "# Color palette and class names for the 150 ADE20K classes\n",
    "colors = scipy.io.loadmat('data/color150.mat')['colors']\n",
    "names = {}\n",
    "with open('data/object150_info.csv') as f:\n",
    "    reader = csv.reader(f)\n",
    "    next(reader)  # skip the header row\n",
    "    for row in reader:\n",
    "        # keys are the 1-based class ids in the CSV; keep the first listed name\n",
    "        names[int(row[0])] = row[5].split(\";\")[0]\n",
    "\n",
    "def visualize_result(img, pred, index=None):\n",
    "    # if a class index is given, keep only that class and mark the rest as -1\n",
    "    if index is not None:\n",
    "        pred = pred.copy()\n",
    "        pred[pred != index] = -1\n",
    "        # predictions are 0-based, while the CSV class ids are 1-based\n",
    "        print(f'{names[index+1]}:')\n",
    "\n",
    "    # colorize prediction\n",
    "    pred_color = colorEncode(pred, colors).astype(numpy.uint8)\n",
    "\n",
    "    # show the input image and the colorized prediction side by side\n",
    "    im_vis = numpy.concatenate((img, pred_color), axis=1)\n",
    "    display(PIL.Image.fromarray(im_vis))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the segmentation model\n",
    "\n",
    "Here we load a pretrained segmentation model.  Like any PyTorch model, we can call it like a function, or examine the parameters in all the layers.\n",
    "\n",
    "Since we are doing inference, not training, we put the model in eval mode.  We also move it to the GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network Builders\n",
    "net_encoder = ModelBuilder.build_encoder(\n",
    "    arch='resnet50dilated',\n",
    "    fc_dim=2048,\n",
    "    weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth')\n",
    "net_decoder = ModelBuilder.build_decoder(\n",
    "    arch='ppm_deepsup',\n",
    "    fc_dim=2048,\n",
    "    num_class=150,\n",
    "    weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth',\n",
    "    use_softmax=True)\n",
    "\n",
    "crit = torch.nn.NLLLoss(ignore_index=-1)\n",
    "segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)\n",
    "segmentation_module.eval()\n",
    "segmentation_module.cuda()"
   ]
  },
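  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration of examining the parameters, the sketch below counts them and lists the first few parameter tensors.  It assumes the `segmentation_module` built in the previous cell.  The original notebook does not state the parameter count, so no expected value is given here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: assumes `segmentation_module` from the cell above.\n",
    "n_params = sum(p.numel() for p in segmentation_module.parameters())\n",
    "print(f'{n_params:,} parameters')\n",
    "\n",
    "# List the first few named parameter tensors and their shapes.\n",
    "for name, param in list(segmentation_module.named_parameters())[:5]:\n",
    "    print(name, tuple(param.shape))"
   ]
  },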
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load test data\n",
    "\n",
    "Now we load and normalize a single test image.  Here we use the common convention of normalizing each RGB channel with the mean and standard deviation of a large photo dataset, so that the values across that dataset would have zero mean and unit standard deviation.  (These numbers come from the ImageNet dataset.)  With this normalization, RGB values fall between about -2.1 and +2.6."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load and normalize one image as a singleton tensor batch\n",
    "pil_to_tensor = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize(\n",
    "        mean=[0.485, 0.456, 0.406], # These are RGB mean+std values\n",
    "        std=[0.229, 0.224, 0.225])  # across a large photo dataset.\n",
    "])\n",
    "pil_image = PIL.Image.open('ADE_val_00001519.jpg').convert('RGB')\n",
    "img_original = numpy.array(pil_image)\n",
    "img_data = pil_to_tensor(pil_image)\n",
    "singleton_batch = {'img_data': img_data[None].cuda()}\n",
    "output_size = img_data.shape[1:]"
   ]
  },
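  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where the limiting range of normalized values comes from, the sketch below applies the same normalization to the extreme pixel values.  `ToTensor` scales 8-bit pixel values to the range 0 to 1, so the per-channel extremes after normalization are `(0 - mean) / std` and `(1 - mean) / std`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: per-channel extremes after normalization (uses only numpy).\n",
    "import numpy\n",
    "\n",
    "mean = numpy.array([0.485, 0.456, 0.406])\n",
    "std = numpy.array([0.229, 0.224, 0.225])\n",
    "print('lowest: ', ((0.0 - mean) / std).round(2))  # a fully dark pixel\n",
    "print('highest:', ((1.0 - mean) / std).round(2))  # a fully bright pixel"
   ]
  },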
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the Model\n",
    "\n",
    "Finally, we pass the test image to the segmentation model.\n",
    "\n",
    "The segmentation model is called with a dictionary holding the input image batch, plus a `segSize` argument giving the desired output segmentation resolution.  We ask for full-resolution output.\n",
    "\n",
    "Then we use the previously defined `visualize_result` function to render the segmentation map."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Run the segmentation at the highest resolution.\n",
    "with torch.no_grad():\n",
    "    scores = segmentation_module(singleton_batch, segSize=output_size)\n",
    "\n",
    "# Get the highest-scoring class for each pixel\n",
    "_, pred = torch.max(scores, dim=1)\n",
    "pred = pred.cpu()[0].numpy()\n",
    "visualize_result(img_original, pred)"
   ]
  },
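  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below checks the shapes involved and summarizes the prediction, assuming `scores`, `pred`, and `names` from the cells above.  With `num_class=150`, we would expect `scores` to have one channel per class and `pred` to match the height and width of the input image.  It then lists the most common predicted classes by share of pixels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: assumes `scores`, `pred`, and `names` from the cells above.\n",
    "print('scores:', tuple(scores.shape))  # expected layout: (batch, classes, height, width)\n",
    "print('pred:  ', pred.shape)           # expected layout: (height, width)\n",
    "\n",
    "# Share of pixels for the five most common predicted classes.\n",
    "class_ids, counts = numpy.unique(pred, return_counts=True)\n",
    "for i in counts.argsort()[::-1][:5]:\n",
    "    # predictions are 0-based, while the keys of `names` are 1-based\n",
    "    print(f'{names[class_ids[i] + 1]}: {100 * counts[i] / pred.size:.1f}% of pixels')"
   ]
  },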
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Showing classes individually\n",
    "\n",
    "To see which color corresponds to which class, we visualize the most frequent classes one at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class indices sorted by pixel count, most frequent first\n",
    "predicted_classes = numpy.bincount(pred.flatten()).argsort()[::-1]\n",
    "# Show the 15 most frequent classes, one per image\n",
    "for c in predicted_classes[:15]:\n",
    "    visualize_result(img_original, pred, c)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}